{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# In Google Colab, uncomment the block below to install a virtual display and rendering tools\n",
    "\n",
    "# import os\n",
    "\n",
    "# os.system('apt-get install -y xvfb')\n",
    "# os.system('wget https://raw.githubusercontent.com/yandexdataschool/Practical_DL/fall18/xvfb -O ../xvfb')\n",
    "# os.system('apt-get install -y python-opengl ffmpeg')\n",
    "# os.system('pip install pyglet==1.2.4')\n",
    "\n",
    "# When DISPLAY is unset or empty (e.g. on a server), start XVFB and set DISPLAY to ':1'\n",
    "import os\n",
    "if not os.environ.get(\"DISPLAY\"):\n",
    "    !bash ../xvfb start\n",
    "    os.environ['DISPLAY'] = ':1'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Let's make a TRPO!\n",
    "\n",
    "In this notebook we will implement Trust Region Policy Optimization (TRPO).\n",
    "As usual, it consists of a few different parts which we are going to reproduce.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.autograd import Variable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.\u001b[0m\n",
      "Observation Space Box(6,)\n",
      "Action Space Discrete(3)\n"
     ]
    }
   ],
   "source": [
    "import gym\n",
    "\n",
    "# Create the Acrobot-v1 environment used throughout this notebook\n",
    "env = gym.make(\"Acrobot-v1\")\n",
    "env.reset()\n",
    "# Keep the observation shape and the number of actions; later cells use them to build and check the agent\n",
    "observation_shape = env.observation_space.shape\n",
    "n_actions = env.action_space.n\n",
    "print(\"Observation Space\", env.observation_space)\n",
    "print(\"Action Space\", env.action_space)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7f3c2dc56dd8>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQsAAAD8CAYAAABgtYFHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAADjJJREFUeJzt3X3I3Wd9x/H3Z+mDborpw70Qkkgqhkn/2GpzUyvKcC2O2onpH1VaZAYJBDYHFQcu3WBD2B+6P6wKQw2rLA617XygoXRzXVoZ+8PaO/bBPqz2rrQ0oZqobd0Q3arf/XGu6DGmua8793lM3i84nOt3/a7fOd9TTj69fr9znXOnqpCklfzGtAuQNB8MC0ldDAtJXQwLSV0MC0ldDAtJXcYSFkmuSvJ4kuUke8bxHJImK6NeZ5FkHfBt4K3AIeA+4PqqenSkTyRposYxs7gMWK6q71TV/wK3ADvG8DySJuisMTzmJuCZoe1DwBtOdsCFF15YW7duHUMpko45ePDg96tq4VSPH0dYdEmyG9gN8OpXv5qlpaVplSKdEZI8vZbjx3EachjYMrS9ufX9iqraW1WLVbW4sHDKYSdpQsYRFvcB25JclOQc4Dpg/xieR9IEjfw0pKpeTPJnwFeBdcBnquqRUT+PpMkayzWLqroTuHMcjy1pOlzBKamLYSGpi2EhqYthIamLYSGpi2EhqYthIamLYSGpi2EhqYthIamLYSGpi2EhqYthIamLYSGpi2EhqYthIamLYSGpi2EhqYthIamLYSGpi2EhqYthIamLYSGpi2EhqYthIamLYSGpi2EhqYthIamLYSGpi2EhqYthIamLYSGpi2EhqYthIamLYSGpy4phkeQzSY4keXio7/wkdyV5ot2f1/qT5BNJlpM8lOTScRYvaXJ6Zhb/CFx1XN8e4EBVbQMOtG2AtwHb2m038MnRlClp2lYMi6r6D+CHx3XvAPa19j7gmqH+z9bA14H1STaOqlhJ03Oq1yw2VNWzrf1dYENrbwKeGRp3qPX9miS7kywlWTp69OgpliFpUtZ8gbOqCqhTOG5vVS1W1eLCwsJay5A0ZqcaFt87dnrR7o+0/sPAlqFxm1ufpDl3qmGxH9jZ2juB24f639M+FbkceGHodEXSHDtrpQFJvgC8BbgwySHgb4APA7cl2QU8DbyrDb8TuBpYBn4MvHcMNUuaghXDoqquf4ldV55gbAHvW2tRkmaPKzgldTEsJHUxLCR1MSwkdTEsJHUxLCR1MSwkdTEsJHUxLCR1yWDR5ZSLSKZfhHT6O1hVi6d68IrLvSdh+/btLC0tTbsM6bSWZE3HexoiqYthIamLYSGpi2EhqYthIamLYSGpi2EhqYthIamLYSGpi2EhqYthIamLYSGpi2EhqYthIamLYSGpi2EhqYthIamLYSGpi2EhqYthIamLYSGpi2EhqYthIanLimGRZEuSe5I8muSRJDe0/vOT3JXkiXZ/XutPkk8kWU7yUJJLx/0iJI1fz8ziReDPq+pi4HLgfUkuBvYAB6pqG3CgbQO8DdjWbruBT468akkTt2JYVNWzVfXN1v5v4DFgE7AD2NeG7QOuae0dwGdr4OvA+iQbR165pIla1TWLJFuB1wP3Ahuq6tm267vAhtbeBDwzdNih1idpjnWHRZJXAF8C3l9VPxreV4O/rryqP26cZHeSpSRLR48eXc2hkqagKyySnM0gKD5XVV9u3d87dnrR7o+0/sPAlqHDN7e+X1FVe6tqsaoWFxYWTrV+SRPS82lIgJuBx6rqo0O79gM7W3sncPtQ/3vapyKXAy8Mna5ImlNndYx5E/DHwLeSPND6/hL4MHBbkl3A08C72r47gauBZeDHwHtHWrGkqVgxLKrqP4G8xO4rTzC+gPetsS5JM8YVnJK6GBaSuhgWkroYFpK6GBaSuhgWkroYFpK6GBaSuhgWkroYFpK6GBaS
uvR8kUz6hYMHf/VrQtu3r+pnTDTHnFmo2/FB8VJ9Oj0ZFupyslAwMM4MhoVW1BMGBsbpz7CQ1MWwkNTFsJDUxbDQihZZGskYzTfDQl1OFgYGxZnBsFC3E4WCQXHmcAWnVsVwOHM5s5DUxbCQ1MWwkNTFsJDUxbCQ1MWwkNTFsJDUxbCQ1MWwkNTFsJDUxbCQ1MWwkNTFsJDUZcWwSPKyJN9I8mCSR5J8qPVflOTeJMtJbk1yTus/t20vt/1bx/sSJE1Cz8zip8AVVfV7wCXAVUkuBz4C3FRVrwWeA3a18buA51r/TW2cpDm3YljUwP+0zbPbrYArgC+2/n3ANa29o23T9l+ZxN+Jl+Zc1zWLJOuSPAAcAe4CngSer6oX25BDwKbW3gQ8A9D2vwBccILH3J1kKcnS0aNH1/YqNHW1ffu0S9CYdYVFVf2sqi4BNgOXAa9b6xNX1d6qWqyqxYWFhbU+nKQxW9WnIVX1PHAP8EZgfZJjP8u3GTjc2oeBLQBt/6uAH4ykWklT0/NpyEKS9a39cuCtwGMMQuPaNmwncHtr72/btP13V5V/aluacz0/2LsR2JdkHYNwua2q7kjyKHBLkr8F7gdubuNvBv4pyTLwQ+C6MdQtacJWDIuqegh4/Qn6v8Pg+sXx/T8B3jmS6iTNDFdwSupiWOikcvDgtEvQjDAsJHUxLCR1MSwkdTEsJHUxLCR1MSwkdTEsJHUxLCR1MSwkdTEsJHUxLCR1MSwkdTEsJHUxLCR1MSy0KkssssTitMvQFPT8rJ70awFxbHuRpWmUoylwZqEVnWwm4SzjzGFYSOpiWOikemYOzi7ODIaFpC6GhaQuhoVOqufTDj8ROTMYFpK6GBZa0clmDs4qzhwuytJJ1fbt5OBBQ0HOLDQa/uWy059hIamLYSGpi2EhqYthIamLYSGpi2EhqUt3WCRZl+T+JHe07YuS3JtkOcmtSc5p/ee27eW2f+t4Spc0SauZWdwAPDa0/RHgpqp6LfAcsKv17wKea/03tXGS5lxXWCTZDPwR8A9tO8AVwBfbkH3ANa29o23T9l/ZxkuaY70zi48BHwR+3rYvAJ6vqhfb9iFgU2tvAp4BaPtfaOMlzbEVwyLJ24EjVTXS9bxJdidZSrJ09OjRUT60Rqy2b592CZoBPTOLNwHvSPIUcAuD04+PA+uTHPsi2mbgcGsfBrYAtP2vAn5w/INW1d6qWqyqxYWFhTW9CM0Gvx9yelsxLKrqxqraXFVbgeuAu6vq3cA9wLVt2E7g9tbe37Zp+++uqhpp1ZImbi3rLP4C+ECSZQbXJG5u/TcDF7T+DwB71laipFmwqt+zqKqvAV9r7e8Al51gzE+Ad46gNkkzxBWckroYFpK6GBaSuhgWkroYFpK6GBaSuhgW6uKSbxkWGimXfJ++DAtJXQwLSV0MC0ldDAtJXQwLSV0MC0ldDAtJXQwLSV0MC0ldDAtJXQwLSV0MC0ldDAtJXQwLSV0MC0ldDAtJXQwLdfPXss5shoVGzl/LOj0ZFpK6GBaSuqzqDyNLXrc4czmzkNTFsJDUxbCQ1MWwkNTFsJDUxbCQ1KUrLJI8leRbSR5IstT6zk9yV5In2v15rT9JPpFkOclDSS4d5wuQNBmrmVn8QVVdUlWLbXsPcKCqtgEH2jbA24Bt7bYb+OSoipU0PWs5DdkB7GvtfcA1Q/2frYGvA+uTbFzD80iaAb0rOAv4tyQFfLqq9gIbqurZtv+7wIbW3gQ8M3Tsodb37FAfSXYzmHkA/DTJw6dQ/7RcCHx/2kV0mqdaYb7qnadaAX5nLQf3hsWbq+pwkt8G7kryX8M7q6pakHRrgbMXIMnS0OnNzJuneuepVpiveuepVhjUu5bju05Dqupwuz8CfAW4DPjesdOLdn+kDT8MbBk6fHPrkzTHVgyLJL+V5JXH2sAfAg8D+4GdbdhO4PbW3g+8p30qcjnwwtDpiqQ51XMasgH4SpJj4z9fVf+a5D7gtiS7gKeB
d7XxdwJXA8vAj4H3djzH3tUWPmXzVO881QrzVe881QprrDdVq7rUIOkM5QpOSV2mHhZJrkryeFvxuWflI8Zez2eSHBn+KHeWV6sm2ZLkniSPJnkkyQ2zWnOSlyX5RpIHW60fav0XJbm31XRrknNa/7lte7nt3zqpWodqXpfk/iR3zEGt411pXVVTuwHrgCeB1wDnAA8CF0+5pt8HLgUeHur7O2BPa+8BPtLaVwP/AgS4HLh3CvVuBC5t7VcC3wYunsWa23O+orXPBu5tNdwGXNf6PwX8SWv/KfCp1r4OuHUK/30/AHweuKNtz3KtTwEXHtc3svfBRF/MCV7cG4GvDm3fCNw4zZpaHVuPC4vHgY2tvRF4vLU/DVx/onFTrP124K2zXjPwm8A3gTcwWNh01vHvCeCrwBtb+6w2LhOscTODrzJcAdzR/mHNZK3teU8UFiN7H0z7NOSlVnvOmtWuVp2KNvV9PYP/Y89kzW1a/wCDdTl3MZhZPl9VL56gnl/U2va/AFwwqVqBjwEfBH7eti9gdmuFX660PthWSMMI3wf+YO8qVa1+teokJHkF8CXg/VX1o/ZRNzBbNVfVz4BLkqxnsMDvdVMu6YSSvB04UlUHk7xl2vV0GvlK62HTnlnMy2rPmV6tmuRsBkHxuar6cuue6Zqr6nngHgZT+fVJjv2Pa7ieX9Ta9r8K+MGESnwT8I4kTwG3MDgV+fiM1gqMf6X1tMPiPmBbu8J8DoMLQ/unXNOJzOxq1QymEDcDj1XVR4d2zVzNSRbajIIkL2dwbeUxBqFx7UvUeuw1XAvcXe0Ee9yq6saq2lxVWxm8L++uqnfPYq0woZXWk7wA8xIXZa5mcAX/SeCvZqCeLzD4huz/MTiP28Xg3PMA8ATw78D5bWyAv2+1fwtYnEK9b2ZwrvoQ8EC7XT2LNQO/C9zfan0Y+OvW/xrgGwxW/f4zcG7rf1nbXm77XzOl98Rb+OWnITNZa6vrwXZ75Ni/pVG+D1zBKanLtE9DJM0Jw0JSF8NCUhfDQlIXw0JSF8NCUhfDQlIXw0JSl/8Huhr8fpmXAZ4AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7f3c941cfe80>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "# Render one frame as an RGB array and display it to check that rendering works\n",
    "plt.imshow(env.render('rgb_array'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 1: Defining a network\n",
    "\n",
    "With all its complexity, at its core TRPO is yet another policy gradient method. \n",
    "\n",
    "This essentially means we're actually training a stochastic policy $ \\pi_\\theta(a|s) $. \n",
    "\n",
    "And yes, it's gonna be a neural network. So let's start by defining one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TRPOAgent(nn.Module):\n",
    "    \"\"\"Stochastic policy network for TRPO: maps a batch of states to action log-probabilities.\"\"\"\n",
    "\n",
    "    def __init__(self, state_shape, n_actions, hidden_size=32):\n",
    "        '''\n",
    "        Here you should define your model\n",
    "        You should have LOG-PROBABILITIES as output because you will need it to compute loss\n",
    "        We recommend that you start simple: \n",
    "        use 1-2 hidden layers with 100-500 units and relu for the first try\n",
    "\n",
    "        :param state_shape: shape of a single observation\n",
    "        :param n_actions: number of actions the policy chooses between\n",
    "        :param hidden_size: hidden layer width available to the network you define\n",
    "        '''\n",
    "        nn.Module.__init__(self)\n",
    "\n",
    "        <your network here >\n",
    "        self.model = None\n",
    "\n",
    "    def forward(self, states):\n",
    "        \"\"\"\n",
    "        takes a batch of agent's observations, returns log-probabilities of all actions\n",
    "        :param states: a batch of states, shape = [batch_size, state_shape]\n",
    "        \"\"\"\n",
    "\n",
    "        # Use your network to compute log_probs for given state\n",
    "        log_probs = self.model(states)\n",
    "        return log_probs\n",
    "\n",
    "    def get_log_probs(self, states):\n",
    "        '''\n",
    "        Log-probs for training\n",
    "        '''\n",
    "\n",
    "        return self.forward(states)\n",
    "\n",
    "    def get_probs(self, states):\n",
    "        '''\n",
    "        Probs for interaction\n",
    "        '''\n",
    "\n",
    "        return torch.exp(self.forward(states))\n",
    "\n",
    "    def act(self, obs, sample=True):\n",
    "        '''\n",
    "        Samples action from policy distribution (sample = True) or takes most likely action (sample = False)\n",
    "        :param: obs - single observation vector\n",
    "        :param sample: if True, samples from the policy, otherwise takes most likely action\n",
    "        :returns: action (single integer) and probabilities for all actions\n",
    "        '''\n",
    "\n",
    "        # Wrap the single observation into a batch of one and keep only that row of the output\n",
    "        probs = self.get_probs(Variable(torch.FloatTensor([obs]))).data.numpy()\n",
    "        action_probs = probs[0]\n",
    "\n",
    "        if sample:\n",
    "            # The number of actions comes from the policy output itself, so this method\n",
    "            # does not rely on a global `n_actions` being defined in the notebook\n",
    "            action = int(np.random.choice(len(action_probs), p=action_probs))\n",
    "        else:\n",
    "            action = int(np.argmax(action_probs))\n",
    "\n",
    "        return action, action_probs\n",
    "\n",
    "\n",
    "agent = TRPOAgent(observation_shape, n_actions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sampled: [(2, array([0.35253003, 0.37892205, 0.26854792], dtype=float32)), (2, array([0.35269254, 0.37673423, 0.27057323], dtype=float32)), (0, array([0.35406563, 0.37682924, 0.26910514], dtype=float32)), (0, array([0.3560282 , 0.37561142, 0.2683604 ], dtype=float32)), (1, array([0.35539204, 0.37685862, 0.26774937], dtype=float32))]\n",
      "greedy: [(1, array([0.3518883 , 0.37830737, 0.2698043 ], dtype=float32)), (1, array([0.3544095 , 0.37609497, 0.26949552], dtype=float32)), (1, array([0.35528135, 0.37493262, 0.269786  ], dtype=float32)), (1, array([0.3589018 , 0.37457928, 0.26651892], dtype=float32)), (1, array([0.35414994, 0.3769723 , 0.26887777], dtype=float32))]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/container.py:67: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n",
      "  input = module(input)\n"
     ]
    }
   ],
   "source": [
    "# Check if log-probabilities satisfies all the requirements\n",
    "log_probs = agent.get_log_probs(Variable(torch.FloatTensor([env.reset()])))\n",
    "assert isinstance(\n",
    "    log_probs, Variable) and log_probs.requires_grad, \"log_probs must be a torch variable with grad\"\n",
    "assert len(\n",
    "    log_probs.shape) == 2 and log_probs.shape[0] == 1 and log_probs.shape[1] == n_actions, \\\n",
    "    \"log_probs must have shape [1, n_actions]\"\n",
    "# Exponentiated log-probabilities must form a distribution over actions\n",
    "sums = torch.sum(torch.exp(log_probs), dim=1)\n",
    "assert (0.999 < sums).all() and (1.001 > sums).all(), \"action probabilities must sum to 1\"\n",
    "\n",
    "# Demo use\n",
    "print(\"sampled:\", [agent.act(env.reset()) for _ in range(5)])\n",
    "print(\"greedy:\", [agent.act(env.reset(), sample=False) for _ in range(5)])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Flat parameters operations\n",
    "\n",
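    "TRPO's conjugate-gradient and line-search steps operate on a single flat vector of parameters, so we need helpers that flatten the network's parameters and load a flat vector back into the network."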
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_flat_params_from(model):\n",
    "    params = []\n",
    "    for param in model.parameters():\n",
    "        params.append(param.data.view(-1))\n",
    "\n",
    "    flat_params = torch.cat(params)\n",
    "    return flat_params\n",
    "\n",
    "\n",
    "def set_flat_params_to(model, flat_params):\n",
    "    prev_ind = 0\n",
    "    for param in model.parameters():\n",
    "        flat_size = int(np.prod(list(param.size())))\n",
    "        param.data.copy_(\n",
    "            flat_params[prev_ind:prev_ind + flat_size].view(param.size()))\n",
    "        prev_ind += flat_size"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compute cumulative reward just like you did in vanilla REINFORCE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import scipy.signal\n",
    "\n",
    "\n",
    "def get_cummulative_returns(r, gamma=1):\n",
    "    \"\"\"\n",
    "    Computes cummulative discounted rewards given immediate rewards\n",
    "    G_i = r_i + gamma*r_{i+1} + gamma^2*r_{i+2} + ...\n",
    "    Also known as R(s,a).\n",
    "    \"\"\"\n",
    "    r = np.array(r)\n",
    "    assert r.ndim >= 1\n",
    "    return scipy.signal.lfilter([1], [1, -gamma], r[::-1], axis=0)[::-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# simple demo on rewards [0,0,1,0,0,1]\n",
    "# with gamma=0.9 the result should be [1.40049, 1.5561, 1.729, 0.81, 0.9, 1.0]\n",
    "get_cummulative_returns([0, 0, 1, 0, 0, 1], gamma=0.9)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Rollout**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rollout(env, agent, max_pathlength=2500, n_timesteps=50000):\n",
    "    \"\"\"\n",
    "    Generate rollouts for training.\n",
    "    :param: env - environment in which we will make actions to generate rollouts.\n",
    "    :param: agent - policy object; agent.act(observation) must return (action, action probabilities).\n",
    "    :param: max_pathlength - maximum size of one path that we generate.\n",
    "    :param: n_timesteps - total sum of sizes of all paths we generate.\n",
    "    :returns: list of dicts, one per path, with keys 'observations', 'policy', 'actions',\n",
    "        'rewards' and 'cumulative_returns'.\n",
    "    \"\"\"\n",
    "    if max_pathlength < 1:\n",
    "        # With zero steps per path the timestep counter would never advance\n",
    "        raise ValueError('max_pathlength must be at least 1')\n",
    "\n",
    "    paths = []\n",
    "\n",
    "    total_timesteps = 0\n",
    "    while total_timesteps < n_timesteps:\n",
    "        observations, actions, rewards, action_probs = [], [], [], []\n",
    "        observation = env.reset()\n",
    "        for _ in range(max_pathlength):\n",
    "            action, policy = agent.act(observation)\n",
    "            observations.append(observation)\n",
    "            actions.append(action)\n",
    "            action_probs.append(policy)\n",
    "            observation, reward, done, _ = env.step(action)\n",
    "            rewards.append(reward)\n",
    "            total_timesteps += 1\n",
    "            if done or total_timesteps >= n_timesteps:\n",
    "                break\n",
    "\n",
    "        # Store the path however it ended (episode finished, timestep budget exhausted or\n",
    "        # max_pathlength reached), so that every counted timestep is part of a returned path\n",
    "        paths.append({\"observations\": np.array(observations),\n",
    "                      \"policy\": np.array(action_probs),\n",
    "                      \"actions\": np.array(actions),\n",
    "                      \"rewards\": np.array(rewards),\n",
    "                      \"cumulative_returns\": get_cummulative_returns(rewards),\n",
    "                      })\n",
    "    return paths"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoke test: a tiny rollout with paths capped at 5 steps, then check the array shapes stored in the first path\n",
    "paths = rollout(env, agent, max_pathlength=5, n_timesteps=100)\n",
    "print(paths[-1])\n",
    "assert (paths[0]['policy'].shape == (5, n_actions))\n",
    "assert (paths[0]['cumulative_returns'].shape == (5,))\n",
    "assert (paths[0]['rewards'].shape == (5,))\n",
    "assert (paths[0]['observations'].shape == (5,)+observation_shape)\n",
    "assert (paths[0]['actions'].shape == (5,))\n",
    "print('It\\'s ok')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 3: Auxiliary functions\n",
    "\n",
    "Now let's define the loss functions and something else for actual TRPO training."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The surrogate reward should be\n",
    "$$J_{surr}= {1 \\over N} \\sum\\limits_{i=1}^N \\frac{\\pi_{\\theta}(a_i|s_i)}{\\pi_{\\theta_{old}}(a_i|s_i)}A_{\\theta_{old}}(s_i, a_i)$$\n",
    "\n",
    "For simplicity, let's use cumulative returns instead of advantage for now:\n",
    "$$J'_{surr}= {1 \\over N} \\sum\\limits_{i=1}^N \\frac{\\pi_{\\theta}(a_i|s_i)}{\\pi_{\\theta_{old}}(a_i|s_i)}G_{\\theta_{old}}(s_i, a_i)$$\n",
    "\n",
    "Or alternatively, minimize the surrogate loss:\n",
    "$$ L_{surr} = - J'_{surr} $$  \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_loss(agent, observations, actions, cummulative_returns, old_probs):\n",
    "    \"\"\"\n",
    "    Computes TRPO objective (the surrogate loss to be minimized)\n",
    "    :param: agent - policy whose get_log_probs(observations) gives log-probabilities of all actions\n",
    "    :param: observations - batch of observations\n",
    "    :param: actions - batch of actions (used as column indices into the probabilities)\n",
    "    :param: cummulative_returns - batch of cumulative returns\n",
    "    :param: old_probs - batch of probabilities computed by old network, one row per observation\n",
    "    :returns: scalar value of the objective function\n",
    "    \"\"\"\n",
    "    batch_size = observations.shape[0]\n",
    "    log_probs_all = agent.get_log_probs(observations)\n",
    "    probs_all = torch.exp(log_probs_all)\n",
    "\n",
    "    # For every row i pick the probability of the action that was actually taken: probs[i, actions[i]]\n",
    "    probs_for_actions = probs_all[torch.arange(\n",
    "        0, batch_size, out=torch.LongTensor()), actions]\n",
    "    old_probs_for_actions = old_probs[torch.arange(\n",
    "        0, batch_size, out=torch.LongTensor()), actions]\n",
    "\n",
    "    # Compute surrogate loss, aka importance-sampled policy gradient\n",
    "    Loss = <compute surrogate loss >\n",
    "\n",
    "    # The loss must be a 0-dimensional (scalar) tensor\n",
    "    assert Loss.shape == torch.Size([])\n",
    "    return Loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can ascend these gradients as long as our $\\pi_\\theta(a|s)$ satisfies the constraint\n",
    "$$E_{s,\\pi_{\\Theta_{t}}}\\Big[KL(\\pi(\\Theta_{t}, s) \\:||\\:\\pi(\\Theta_{t+1}, s))\\Big]< \\alpha$$\n",
    "\n",
    "\n",
    "where\n",
    "\n",
    "$$KL(p||q) = E_p \\log\\left({p \\over q}\\right)$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_kl(agent, observations, actions, cummulative_returns, old_probs):\n",
    "    \"\"\"\n",
    "    Computes KL-divergence between network policy and old policy\n",
    "    :param: agent - policy whose get_log_probs(observations) gives log-probabilities of all actions\n",
    "    :param: observations - batch of observations\n",
    "    :param: actions - batch of actions\n",
    "    :param: cummulative_returns - batch of cumulative returns (we don't need it actually)\n",
    "    :param: old_probs - batch of probabilities computed by old network\n",
    "    :returns: scalar value of the KL-divergence\n",
    "    \"\"\"\n",
    "    batch_size = observations.shape[0]\n",
    "    log_probs_all = agent.get_log_probs(observations)\n",
    "    probs_all = torch.exp(log_probs_all)\n",
    "\n",
    "    # Compute Kullback-Leibler divergence (see formula above)\n",
    "    # Note: you need to sum KL and entropy over all actions, not just the ones agent took\n",
    "    # The small constant keeps the logarithm finite when an old probability is exactly zero\n",
    "    old_log_probs = torch.log(old_probs+1e-10)\n",
    "\n",
    "    kl = <cumpute kullback-leibler >\n",
    "\n",
    "    # Sanity checks: scalar result, non-negative up to a small tolerance, and not absurdly large\n",
    "    assert kl.shape == torch.Size([])\n",
    "    assert (kl > -0.0001).all() and (kl < 10000).all()\n",
    "    return kl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_entropy(agent, observations):\n",
    "    \"\"\"\n",
    "    Computes entropy of the network policy \n",
    "    :param: observations - batch of observations\n",
    "    :returns: scalar value of the entropy\n",
    "    \"\"\"\n",
    "\n",
    "    observations = Variable(torch.FloatTensor(observations))\n",
    "\n",
    "    batch_size = observations.shape[0]\n",
    "    log_probs_all = agent.get_log_probs(observations)\n",
    "    probs_all = torch.exp(log_probs_all)\n",
    "\n",
    "    entropy = torch.sum(-probs_all * log_probs_all) / batch_size\n",
    "\n",
    "    assert entropy.shape == torch.Size([])\n",
    "    return entropy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Line search**\n",
    "\n",
    "TRPO at its core involves ascending the surrogate policy gradient constrained by KL divergence. \n",
    "\n",
    "In order to enforce this constraint, we're gonna use line search. You can find out more about it [here](https://en.wikipedia.org/wiki/Line_search)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def linesearch(f, x, fullstep, max_kl):\n",
    "    \"\"\"\n",
    "    Linesearch finds the best parameters of neural networks in the direction of fullstep contrainted by KL divergence.\n",
    "    :param: f - function that returns loss, kl and arbitrary third component.\n",
    "    :param: x - old parameters of neural network.\n",
    "    :param: fullstep - direction in which we make search.\n",
    "    :param: max_kl - constraint of KL divergence.\n",
    "    :returns:\n",
    "    \"\"\"\n",
    "    max_backtracks = 10\n",
    "    loss, _, = f(x)\n",
    "    for stepfrac in .5**np.arange(max_backtracks):\n",
    "        xnew = x + stepfrac * fullstep\n",
    "        new_loss, kl = f(xnew)\n",
    "        actual_improve = new_loss - loss\n",
    "        if kl.data.numpy() <= max_kl and actual_improve.data.numpy() < 0:\n",
    "            x = xnew\n",
    "            loss = new_loss\n",
    "    return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Conjugate gradients**\n",
    "\n",
    "Since TRPO involves constrained optimization, we will need to solve Ax=b using conjugate gradients.\n",
    "\n",
    "In general, CG is an algorithm that solves Ax=b where A is positive-definite. In TRPO, A is the Hessian of the KL divergence, which is positive semi-definite; the small damping term added in the code below makes it positive-definite. You can find out more about CG [here](https://en.wikipedia.org/wiki/Conjugate_gradient_method)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from numpy.linalg import inv\n",
    "\n",
    "\n",
    "def conjugate_gradient(f_Ax, b, cg_iters=10, residual_tol=1e-10):\n",
    "    \"\"\"\n",
    "    This method solves system of equation Ax=b using iterative method called conjugate gradients\n",
    "    :f_Ax: function that returns Ax\n",
    "    :b: targets for Ax\n",
    "    :cg_iters: how many iterations this method should do\n",
    "    :residual_tol: epsilon for stability\n",
    "    \"\"\"\n",
    "    p = b.clone()\n",
    "    r = b.clone()\n",
    "    x = torch.zeros(b.size())\n",
    "    rdotr = torch.sum(r*r)\n",
    "    for i in range(cg_iters):\n",
    "        z = f_Ax(p)\n",
    "        v = rdotr / (torch.sum(p*z) + 1e-8)\n",
    "        x += v * p\n",
    "        r -= v * z\n",
    "        newrdotr = torch.sum(r*r)\n",
    "        mu = newrdotr / (rdotr + 1e-8)\n",
    "        p = r + mu * p\n",
    "        rdotr = newrdotr\n",
    "        if rdotr < residual_tol:\n",
    "            break\n",
    "    return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This code validates conjugate gradients\n",
    "# Build a random symmetric matrix A = M^T M for the test system\n",
    "A = np.random.rand(8, 8)\n",
    "A = np.matmul(np.transpose(A), A)\n",
    "\n",
    "\n",
    "def f_Ax(x):\n",
    "    # Matrix-vector product A x, returned as a flat tensor\n",
    "    return torch.matmul(torch.FloatTensor(A), x.view((-1, 1))).view(-1)\n",
    "\n",
    "\n",
    "b = np.random.rand(8)\n",
    "\n",
    "# Reference solution computed with numpy via the normal equations: w = (A^T A)^{-1} A^T b\n",
    "w = np.matmul(np.matmul(inv(np.matmul(np.transpose(A), A)),\n",
    "                        np.transpose(A)), b.reshape((-1, 1))).reshape(-1)\n",
    "# The two printed vectors are expected to be close if conjugate_gradient works\n",
    "print(w)\n",
    "print(conjugate_gradient(f_Ax, torch.FloatTensor(b)).numpy())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 4: training\n",
    "In this section we construct the whole update step function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def update_step(agent, observations, actions, cummulative_returns, old_probs, max_kl,\n",
    "                damping=0.1, cg_iters=10):\n",
    "    \"\"\"\n",
    "    This function does the TRPO update step\n",
    "    :param: agent - policy network whose parameters are updated in place\n",
    "    :param: observations - batch of observations\n",
    "    :param: actions - batch of actions\n",
    "    :param: cummulative_returns - batch of cumulative returns\n",
    "    :param: old_probs - batch of probabilities computed by old network\n",
    "    :param: max_kl - controls how big KL divergence may be between old and new policy every step.\n",
    "    :param: damping - multiple of the input vector added to every Hessian-vector product.\n",
    "    :param: cg_iters - number of conjugate gradient iterations used to find the step direction.\n",
    "    :returns: [loss, kl] - the value of the loss function and the KL between new and old\n",
    "        policies, both evaluated at the new parameters.\n",
    "    \"\"\"\n",
    "\n",
    "    # Here we prepare the information\n",
    "    observations = Variable(torch.FloatTensor(observations))\n",
    "    actions = torch.LongTensor(actions)\n",
    "    cummulative_returns = Variable(torch.FloatTensor(cummulative_returns))\n",
    "    old_probs = Variable(torch.FloatTensor(old_probs))\n",
    "\n",
    "    # Here we compute gradient of the loss function\n",
    "    loss = get_loss(agent, observations, actions,\n",
    "                    cummulative_returns, old_probs)\n",
    "    grads = torch.autograd.grad(loss, agent.parameters())\n",
    "    loss_grad = torch.cat([grad.view(-1) for grad in grads]).data\n",
    "\n",
    "    def Fvp(v):\n",
    "        # Here we compute Fx to do solve Fx = g using conjugate gradients\n",
    "        # We actually do here a couple of tricks to compute it efficiently:\n",
    "        # the product of the KL Hessian with v is obtained by differentiating (grad KL . v),\n",
    "        # so the Hessian itself is never built\n",
    "\n",
    "        kl = get_kl(agent, observations, actions,\n",
    "                    cummulative_returns, old_probs)\n",
    "\n",
    "        grads = torch.autograd.grad(kl, agent.parameters(), create_graph=True)\n",
    "        flat_grad_kl = torch.cat([grad.view(-1) for grad in grads])\n",
    "\n",
    "        kl_v = (flat_grad_kl * Variable(v)).sum()\n",
    "        grads = torch.autograd.grad(kl_v, agent.parameters())\n",
    "        flat_grad_grad_kl = torch.cat(\n",
    "            [grad.contiguous().view(-1) for grad in grads]).data\n",
    "\n",
    "        # The damping term keeps the system well conditioned for conjugate gradients\n",
    "        return flat_grad_grad_kl + v * damping\n",
    "\n",
    "    # Here we solve Fx = g system using conjugate gradients\n",
    "    stepdir = conjugate_gradient(Fvp, -loss_grad, cg_iters)\n",
    "\n",
    "    # Here we compute the initial vector to do line search:\n",
    "    # scale the direction so that the quadratic estimate of KL, 0.5 * s^T F s, equals max_kl\n",
    "    shs = 0.5 * (stepdir * Fvp(stepdir)).sum(0, keepdim=True)\n",
    "\n",
    "    lm = torch.sqrt(shs / max_kl)\n",
    "    fullstep = stepdir / lm[0]\n",
    "\n",
    "    # Here we get the start point\n",
    "    prev_params = get_flat_params_from(agent)\n",
    "\n",
    "    def get_loss_kl(params):\n",
    "        # Helper for line search: loads params into the agent, then evaluates loss and KL\n",
    "        set_flat_params_to(agent, params)\n",
    "        return [get_loss(agent, observations, actions, cummulative_returns, old_probs),\n",
    "                get_kl(agent, observations, actions, cummulative_returns, old_probs)]\n",
    "\n",
    "    # Here we find our new parameters\n",
    "    new_params = linesearch(get_loss_kl, prev_params, fullstep, max_kl)\n",
    "\n",
    "    # And we set it to our network\n",
    "    set_flat_params_to(agent, new_params)\n",
    "\n",
    "    return get_loss_kl(new_params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 5: Main TRPO loop\n",
    "\n",
    "Here we will train our network!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "from itertools import count\n",
    "from collections import OrderedDict\n",
    "\n",
    "# this is hyperparameter of TRPO. It controls how big KL divergence may be between old and new policy every step.\n",
    "max_kl = 0.01\n",
    "numeptotal = 0  # this is number of episodes that we played.\n",
    "\n",
    "start_time = time.time()\n",
    "\n",
    "# count(1) never ends: the loop runs until the cell is interrupted\n",
    "for i in count(1):\n",
    "\n",
    "    print(\"\\n********** Iteration %i ************\" % i)\n",
    "\n",
    "    # Generating paths.\n",
    "    print(\"Rollout\")\n",
    "    paths = rollout(env, agent)\n",
    "    print(\"Made rollout\")\n",
    "\n",
    "    # Updating policy: join all paths into flat batches first.\n",
    "    observations = np.concatenate([path[\"observations\"] for path in paths])\n",
    "    actions = np.concatenate([path[\"actions\"] for path in paths])\n",
    "    returns = np.concatenate([path[\"cumulative_returns\"] for path in paths])\n",
    "    old_probs = np.concatenate([path[\"policy\"] for path in paths])\n",
    "\n",
    "    loss, kl = update_step(agent, observations, actions,\n",
    "                           returns, old_probs, max_kl)\n",
    "\n",
    "    # Report current progress\n",
    "    episode_rewards = np.array([path[\"rewards\"].sum() for path in paths])\n",
    "\n",
    "    stats = OrderedDict()\n",
    "    numeptotal += len(episode_rewards)\n",
    "    stats[\"Total number of episodes\"] = numeptotal\n",
    "    stats[\"Average sum of rewards per episode\"] = episode_rewards.mean()\n",
    "    stats[\"Std of rewards per episode\"] = episode_rewards.std()\n",
    "    stats[\"Time elapsed\"] = \"%.2f mins\" % ((time.time() - start_time)/60.)\n",
    "    stats[\"KL between old and new distribution\"] = kl.data.numpy()\n",
    "    stats[\"Entropy\"] = get_entropy(agent, observations).data.numpy()\n",
    "    stats[\"Surrogate loss\"] = loss.data.numpy()\n",
    "    for k, v in stats.items():\n",
    "        print(k + \": \" + \" \" * (40 - len(k)) + str(v))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Homework option I: better sampling (10+pts)\n",
    "\n",
    "In this section, you're invited to implement a better rollout strategy called _vine_.\n",
    "\n",
    "![img](https://s17.postimg.cc/i90chxgvj/vine.png)\n",
    "\n",
    "In most gym environments, you can actually backtrack by using states. You can find a wrapper that saves/loads states in [the mcts seminar](https://github.com/yandexdataschool/Practical_RL/blob/master/week10_planning/seminar_MCTS.ipynb).\n",
    "\n",
    "You can read more about in the [TRPO article](https://arxiv.org/abs/1502.05477) in section 5.2.\n",
    "\n",
    "The goal here is to implement such rollout policy (we recommend using tree data structure like in the seminar above).\n",
    "Then you can assign cumulative rewards similar to `get_cummulative_returns`, but for a tree.\n",
    "\n",
    "__bonus task__ - parallelize samples using multiple cores"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Homework option II (10+pts)\n",
    "\n",
    "Let's use TRPO to train evil robots! (pick any of two)\n",
    "* [MuJoCo robots](https://gym.openai.com/envs#mujoco)\n",
    "* [Box2d robot](https://gym.openai.com/envs/BipedalWalker-v2)\n",
    "\n",
    "The catch here is that those environments have continuous action spaces. \n",
    "\n",
    "Luckily, TRPO is a policy gradient method, so it's gonna work for any parametric $\\pi_\\theta(a|s)$. We recommend starting with gaussian policy:\n",
    "\n",
    "$$\\pi_\\theta(a|s) = N(\\mu_\\theta(s),\\sigma^2_\\theta(s)) = {1 \\over \\sqrt { 2 \\pi {\\sigma^2}_\\theta(s) } } e^{ - { (a - \n",
    "\\mu_\\theta(s))^2 \\over 2 {\\sigma^2}_\\theta(s) } } $$\n",
    "\n",
    "In the $\\sqrt { 2 \\pi {\\sigma^2}_\\theta(s) }$ clause, $\\pi$ means ~3.1415926, not agent's policy.\n",
    "\n",
    "This essentially means that you will need two output layers:\n",
    "* $\\mu_\\theta(s)$, a dense layer with linear activation\n",
    "* ${\\sigma^2}_\\theta(s)$, a dense layer with activation torch.exp (to make it positive; like rho from bandits)\n",
    "\n",
    "For multidimensional actions, you can use fully factorized gaussian (basically a vector of gaussians).\n",
    "\n",
    "__bonus task__: compare performance of continuous action space method to action space discretization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
